import os

import matplotlib.pyplot as plt
import numpy as np
# NOTE(review): sklearn.externals.joblib was removed in scikit-learn 0.23;
# on newer installs this line must become a plain "import joblib".
from sklearn.externals import joblib
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import PolynomialFeatures

from src.db_redis import RedisClient


class Regression:
    """Per-city weather forecaster trained on samples stored in Redis.

    For each city, three models are fitted on a single date feature
    (``month + day / 100``):

    * two linear regressions on a degree-2 polynomial expansion of the date,
      predicting the high ("top") and low values;
    * a k-nearest-neighbours classifier predicting an integer weather id.

    Models are persisted with joblib under ``MODEL_DIR``. The id -> weather
    label map needed to decode the classifier's output is stored in Redis.
    """

    # Minimum number of rows a city needs before ``train`` fits anything.
    MIN_SAMPLES = 100
    # Degree of the polynomial expansion applied to the date feature.
    POLY_DEGREE = 2
    # Directory the fitted models are written to and loaded from.
    MODEL_DIR = "model"

    def __init__(self):
        self.dao = RedisClient()

    # ------------------------------------------------------------------ #
    # Shared helpers (used by train, predict and test so the feature
    # encoding can never drift between them).
    # ------------------------------------------------------------------ #

    @staticmethod
    def _encode_date(month, day):
        """Pack a (month, day) pair into one float, e.g. (5, 20) -> 5.2."""
        return month + day / 100

    @classmethod
    def _parse_rows(cls, rows):
        """Split raw Redis rows into date, top, low and weather columns.

        Each row is a string that evaluates to an indexable record where
        [1] is the month, [2] the day, [3] the top value, [4] the low value
        and [5] the weather label.
        """
        dates, tops, lows, weathers = [], [], [], []
        for raw in rows:
            # NOTE(review): eval() executes arbitrary code if the Redis store
            # is ever writable by an untrusted party; if rows are plain
            # literals, ast.literal_eval is the safe replacement.
            item = eval(raw)
            dates.append(cls._encode_date(item[1], item[2]))
            tops.append(item[3])
            lows.append(item[4])
            weathers.append(item[5])
        return dates, tops, lows, weathers

    @staticmethod
    def _encode_weather(labels):
        """Map weather labels to integer class ids.

        Ids are assigned in order of first appearance, so the same input
        always yields the same ids. Unlike set() iteration order, this does
        not change between processes.

        Returns:
            (codes, index_to_label): the encoded labels, and the reverse map
            from id to original label.
        """
        index_to_label = dict(enumerate(dict.fromkeys(labels)))
        label_to_index = {label: idx for idx, label in index_to_label.items()}
        codes = [label_to_index[label] for label in labels]
        return codes, index_to_label

    @classmethod
    def _poly(cls, date_col):
        """Expand an (n, 1) date column into polynomial features."""
        return PolynomialFeatures(degree=cls.POLY_DEGREE).fit_transform(date_col)

    @classmethod
    def _model_path(cls, city, kind):
        """Return the pickle path of the ``kind`` model for ``city``."""
        return "{}/{}{}_clf.pkl".format(cls.MODEL_DIR, city, kind)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def train(self, city):
        """Fit and persist the top, low and weather models for ``city``.

        Returns None without doing anything when Redis has no data for the
        city or fewer than ``MIN_SAMPLES`` rows.
        """
        # Fetch this city's raw rows from Redis.
        rows = self.dao.list_all(city)
        if rows is None or len(rows) < self.MIN_SAMPLES:
            return

        dates, tops, lows, weathers = self._parse_rows(rows)

        # The classifier needs numeric targets; store id -> label so that
        # predict() can translate the classifier output back.
        weather_codes, index_to_label = self._encode_weather(weathers)
        self.dao.save_map("{}map".format(city), index_to_label)

        date_col = np.asarray(dates).reshape(-1, 1)
        date_poly = self._poly(date_col)

        top_clf = LinearRegression().fit(date_poly, tops)
        low_clf = LinearRegression().fit(date_poly, lows)
        # KNN works on the raw date, not the polynomial expansion.
        weather_clf = KNeighborsClassifier().fit(date_col, weather_codes)

        # joblib.dump does not create missing directories.
        os.makedirs(self.MODEL_DIR, exist_ok=True)
        joblib.dump(top_clf, self._model_path(city, "top"))
        joblib.dump(low_clf, self._model_path(city, "low"))
        joblib.dump(weather_clf, self._model_path(city, "weather"))

    def predict(self, city, month, day):
        """Forecast ``(top, low, weather)`` for ``city`` on ``month``/``day``.

        Loads the models written by ``train``. The top and low values are
        rounded to int. The predicted weather id is decoded through the map
        saved in Redis. The result is printed and returned.
        """
        top_clf = joblib.load(self._model_path(city, "top"))
        low_clf = joblib.load(self._model_path(city, "low"))
        weather_clf = joblib.load(self._model_path(city, "weather"))

        date_col = np.asarray([self._encode_date(month, day)]).reshape(-1, 1)
        date_poly = self._poly(date_col)

        top = int(round(top_clf.predict(date_poly)[0]))
        low = int(round(low_clf.predict(date_poly)[0]))
        weather_id = weather_clf.predict(date_col)[0]
        # Redis returns map keys as strings, hence str().
        weather = self.dao.get_map("{}map".format(city), str(weather_id))
        print(city, top, low, weather)
        return top, low, weather

    def test(self, city="萍乡", month=5, day=20):
        """Fit models on ``city``'s data in memory and visualise the weather fit.

        Debug aid: nothing is persisted. Prints the encoded weather ids and
        the top-value prediction for ``month``/``day``. Then shows a scatter
        plot of actual (red) vs. fitted weather ids over the training dates.
        Returns None immediately if there is no data for ``city``.
        """
        rows = self.dao.list_all(city)
        if not rows:
            return

        dates, tops, _lows, weathers = self._parse_rows(rows)
        weather_codes, _ = self._encode_weather(weathers)
        print(weather_codes)

        date_col = np.asarray(dates).reshape(-1, 1)
        top_clf = LinearRegression().fit(self._poly(date_col), tops)
        weather_clf = KNeighborsClassifier().fit(date_col, weather_codes)

        query = self._poly(np.asarray([[self._encode_date(month, day)]]))
        print(top_clf.predict(query)[0])

        plt.scatter(date_col, weather_codes, c="r")
        plt.scatter(date_col, weather_clf.predict(date_col))
        plt.show()
